# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: RainbowSecret
# Modified from: https://github.com/AlexHex7/Non-local_pytorch
# Microsoft Research
# yuyua@microsoft.com
# Copyright (c) 2018
##
# This source code is licensed under the MIT-style license found in the
# LICENSE file in the root directory of this source tree
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import torch.nn as nn
from torch.nn import functional as F
import math
import torch.utils.model_zoo as model_zoo
import torch
import os
import sys
import pdb
import numpy as np
from torch.autograd import Variable
import functools

# NOTE(review): keeps only the first three characters of the version string
# (e.g. "0.4"); a two-digit minor version such as "1.10" becomes "1.1".
# Not used in this part of the file -- check other users before relying on it.
torch_ver = torch.__version__[:3]

# Make the project-local `inplace_abn` directory (a sibling of this file's
# directory) importable, so that `bn` below can be resolved.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, '../inplace_abn'))
from bn import InPlaceABNSync
# InPlaceABNSync with its activation explicitly set to 'none', i.e. used as a
# plain normalization layer (presumably synchronized BN -- see `bn`).
BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')

# NOTE(review): this module is not on the path added above; presumably it is
# found next to this file / via the working directory -- confirm the layout.
from base_oc_block import BaseOC_Context_Module


class ASP_OC_Module(nn.Module):
    """ASPP-style head with an object-context branch.

    Five parallel branches are applied to the same input and their outputs
    are concatenated along dim 1:

    * ``context`` -- 3x3 conv (with bias) -> ``InPlaceABNSync`` ->
      ``BaseOC_Context_Module`` (key channels ``out_features // 2``,
      value channels ``out_features``, ``sizes=[2]``, no dropout).
    * ``conv2`` -- 1x1 conv -> ``InPlaceABNSync``.
    * ``conv3`` / ``conv4`` / ``conv5`` -- 3x3 convs with dilation (and equal
      padding) ``dilations[0]`` / ``dilations[1]`` / ``dilations[2]``, each
      followed by ``InPlaceABNSync``.

    The ``5 * out_features`` concatenated channels are fused back to
    ``out_features`` by ``conv_bn_dropout`` (1x1 conv -> ``InPlaceABNSync``
    -> ``Dropout2d(0.1)``).

    Args:
        features: number of input channels.
        out_features: channels produced by each branch and by the module.
        dilations: dilation rates of the three atrous 3x3 branches; only the
            first three entries are used.

    NOTE(review): ``torch.cat`` needs every branch to keep the same spatial
    size. The conv branches do (stride 1, padding == dilation for 3x3).
    ``BaseOC_Context_Module`` is assumed to as well -- confirm.
    """

    def __init__(self, features, out_features=512, dilations=(12, 24, 36)):
        super(ASP_OC_Module, self).__init__()
        # Submodule names, their creation order and their internal layout
        # (index 0 = conv, index 1 = norm) are kept stable so existing
        # checkpoints / state_dict keys still load.
        self.context = nn.Sequential(
            nn.Conv2d(features, out_features, kernel_size=3, padding=1,
                      dilation=1, bias=True),
            InPlaceABNSync(out_features),
            BaseOC_Context_Module(
                in_channels=out_features,
                out_channels=out_features,
                key_channels=out_features // 2,
                value_channels=out_features,
                dropout=0,
                sizes=[2]))
        self.conv2 = self._conv_bn(features, out_features, kernel_size=1,
                                   padding=0, dilation=1)
        self.conv3 = self._conv_bn(features, out_features, kernel_size=3,
                                   padding=dilations[0], dilation=dilations[0])
        self.conv4 = self._conv_bn(features, out_features, kernel_size=3,
                                   padding=dilations[1], dilation=dilations[1])
        self.conv5 = self._conv_bn(features, out_features, kernel_size=3,
                                   padding=dilations[2], dilation=dilations[2])

        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(out_features * 5, out_features, kernel_size=1,
                      padding=0, dilation=1, bias=False),
            InPlaceABNSync(out_features),
            nn.Dropout2d(0.1))

    @staticmethod
    def _conv_bn(in_channels, out_channels, kernel_size, padding, dilation):
        """Build a bias-free ``Conv2d`` followed by ``InPlaceABNSync``."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      padding=padding, dilation=dilation, bias=False),
            InPlaceABNSync(out_channels))

    def _branches(self):
        """Return the five parallel branches in concatenation order."""
        return (self.context, self.conv2, self.conv3, self.conv4, self.conv5)

    def _cat_each(self, feat1, feat2, feat3, feat4, feat5):
        """Concatenate the i-th entry of every branch output along dim 1.

        Args:
            feat1 ... feat5: equal-length sequences of tensors, one sequence
                per branch.

        Returns:
            A list whose i-th item is ``torch.cat`` of the i-th entries.

        Raises:
            ValueError: if the five sequences differ in length (previously
                only ``feat1``/``feat2`` were checked, so shorter later
                sequences raised IndexError and longer ones were truncated).
        """
        groups = (feat1, feat2, feat3, feat4, feat5)
        lengths = [len(g) for g in groups]
        if len(set(lengths)) > 1:
            raise ValueError('branch outputs differ in length: %s' % lengths)
        return [torch.cat(parts, 1) for parts in zip(*groups)]

    def forward(self, x):
        """Apply the head.

        Args:
            x: a feature tensor, or a tuple/list of feature tensors that are
                processed independently of each other.

        Returns:
            For tensor input, a tensor with ``out_features`` channels. For
            tuple/list input, a list of such tensors, one per element.

        Raises:
            RuntimeError: if ``x`` is neither a tensor nor a tuple/list.
        """
        if isinstance(x, (tuple, list)):
            # Every branch runs on each element separately (nn.Conv2d cannot
            # take a list). The per-element outputs are then fused one by one.
            feats = [[branch(xi) for xi in x] for branch in self._branches()]
            return [self.conv_bn_dropout(out)
                    for out in self._cat_each(*feats)]
        # `Variable` keeps pre-0.4 PyTorch working; newer versions merged it
        # into `torch.Tensor`.
        if not isinstance(x, (Variable, torch.Tensor)):
            raise RuntimeError('unknown input type')
        out = torch.cat([branch(x) for branch in self._branches()], 1)
        return self.conv_bn_dropout(out)
